/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/

/*!
 * \file kernel_operator_vec_gather_impl.h
 * \brief Gather / Gatherb vector operator implementations based on MicroAPI intrinsics.
 */
#ifndef ASCENDC_MODULE_OPERATOR_VEC_GATHER_IMPL_H
#define ASCENDC_MODULE_OPERATOR_VEC_GATHER_IMPL_H
#include "kernel_utils.h"
#include "kernel_operator_common_impl.h"
#include "kernel_operator_vec_template_impl.h"

namespace AscendC {
/* **************************************************************************************************
 * Gather                                             *
 * ************************************************************************************************* */
// Vector-function body of the Level 0 Gather for B16 element types.
// Per repeat: load two vregs of uint32_t byte offsets, convert them to B16 element
// indices, merge them into a single uint16_t index vreg, gather one vreg of B16
// data from src + srcBaseIndex, and store it to dst with a stride of dstRepStride blocks.
// isNormalMode == true : destination mask is the pre-set vector mask (MoveMask).
// isNormalMode == false: counter mode; the mask is refreshed from the remaining
//                        element count (maskCount) on every repeat.
// NOTE(review): dstRepeatCount is unused in this body; presumably kept for
// signature parity with the other Vf variants — confirm.
template <typename T, bool isNormalMode>
__simd_vf__ inline void VfGatherApi0B16(__ubuf__ T *dst, __ubuf__ T *src, __ubuf__ uint32_t *srcOffsetLocal,
    const uint32_t srcBaseIndex, const uint8_t repeatTime, const uint16_t dstRepStride, uint32_t dstRepeatCount,
    uint32_t u32OffsetRepeatCount, uint32_t blkCount, const uint64_t maskCount)
{
    MicroAPI::RegTensor<uint32_t> offsetReg0;
    MicroAPI::RegTensor<uint32_t> offsetReg1;
    MicroAPI::RegTensor<uint16_t> indexReg;
    MicroAPI::RegTensor<uint16_t> dstReg;
    // Remaining-element counter consumed by UpdateMask in counter mode.
    uint32_t sregPlt = static_cast<uint32_t>(maskCount);
    // Full B32-lane mask used for the index arithmetic below.
    MicroAPI::MaskReg indexMask = MicroAPI::CreateMask<uint32_t>();
    // Half-pattern B16 mask used by Select to join the two packed index halves.
    MicroAPI::MaskReg selectMask = MicroAPI::CreateMask<uint16_t, MicroAPI::MaskPattern::H>();
    MicroAPI::MaskReg dstMask;
    if constexpr (isNormalMode) {
        dstMask = MicroAPI::MoveMask<T>();
    }
    for (uint16_t i = 0; i < static_cast<uint16_t>(repeatTime); ++i) {
        if constexpr (!isNormalMode) {
            dstMask = MicroAPI::UpdateMask<T>(sregPlt);
        }
        DataCopy(offsetReg0, srcOffsetLocal + (2 * i) * u32OffsetRepeatCount);
        DataCopy(offsetReg1, srcOffsetLocal + (2 * i + 1) * u32OffsetRepeatCount);
        // convert addr offset into B16 element index: divide by 2 (implemented by ShiftRight 1 bit)
        ShiftRights(offsetReg0, offsetReg0, (int16_t)1, indexMask);
        ShiftRights(offsetReg1, offsetReg1, (int16_t)1, indexMask);
        // extract the lower 16-bit of each uint32_t offset into uint16_t index data:
        // for offsetReg0, pack every lower 16-bit into the lower half of the vreg:
        // 0x00FF00FE00FD... -> 0xFFFEFD...000000...
        // for offsetReg1, pack every higher 16-bit into the higher half of the vreg:
        // 0x001100120013... -> 0x000000...111213...
        MicroAPI::Pack<uint16_t, uint32_t, MicroAPI::HighLowPart::LOWEST>((MicroAPI::RegTensor<uint16_t> &)offsetReg0,
            offsetReg0);
        MicroAPI::Pack<uint16_t, uint32_t, MicroAPI::HighLowPart::HIGHEST>((MicroAPI::RegTensor<uint16_t> &)offsetReg1,
            offsetReg1);
        // Select the effective data in offsetReg0 and offsetReg1 and joint them into a complete uint16_t type
        // indexReg: 0xFFFEFD...111213...
        Select(indexReg, (MicroAPI::RegTensor<uint16_t> &)offsetReg0, (MicroAPI::RegTensor<uint16_t> &)offsetReg1,
            selectMask);
        DataCopyGather(dstReg, (__ubuf__ uint16_t *)src + srcBaseIndex, indexReg, dstMask);
        DataCopy((__ubuf__ uint16_t *)dst + i * dstRepStride * blkCount, dstReg, dstMask);
    }
}

// Entry wrapper of the Level 0 B16 Gather: derives the per-repeat element counts
// from the hardware register/block geometry and delegates to the vector-function body.
template <typename T, bool isNormalMode>
__aicore__ inline void GatherApi0B16Impl(__ubuf__ T *dst, __ubuf__ T *src, __ubuf__ uint32_t *srcOffsetLocal,
    const uint32_t srcBaseIndex, const uint8_t repeatTime, const uint16_t &dstRepStride, const uint64_t maskCount)
{
    const uint32_t elemPerRepeat = static_cast<uint32_t>(VECTOR_REG_WIDTH / sizeof(T));
    const uint32_t offsetElemPerRepeat = static_cast<uint32_t>(VECTOR_REG_WIDTH / sizeof(uint32_t));
    const uint32_t elemPerBlk = static_cast<uint32_t>(ONE_BLK_SIZE / sizeof(T));
    VfGatherApi0B16<T, isNormalMode>(dst, src, srcOffsetLocal, srcBaseIndex, repeatTime, dstRepStride, elemPerRepeat,
        offsetElemPerRepeat, elemPerBlk, maskCount);
}

// Vector-function body of the Level 0 Gather for B32 element types.
// Per repeat: load one vreg of uint32_t byte offsets, convert them to B32 element
// indices, gather one vreg of B32 data from src + srcBaseIndex, and store it to
// dst with a stride of dstRepStride blocks.
// isNormalMode == true : destination mask is the pre-set vector mask (MoveMask).
// isNormalMode == false: counter mode; the mask is refreshed from the remaining
//                        element count (maskCount) on every repeat.
// NOTE(review): dstRepeatCount is unused in this body; presumably kept for
// signature parity with the other Vf variants — confirm.
template <typename T, bool isNormalMode>
__simd_vf__ inline void VfGatherApi0B32(__ubuf__ T *dst, __ubuf__ T *src, __ubuf__ uint32_t *srcOffsetLocal,
    const uint32_t srcBaseIndex, const uint8_t repeatTime, const uint16_t dstRepStride, uint32_t dstRepeatCount,
    uint32_t u32OffsetRepeatCount, uint32_t blkCount, const uint64_t maskCount)
{
    MicroAPI::RegTensor<uint32_t> offsetReg;
    MicroAPI::RegTensor<uint32_t> indexReg;
    MicroAPI::RegTensor<uint32_t> dstReg;
    // Remaining-element counter consumed by UpdateMask in counter mode.
    uint32_t sregPlt = static_cast<uint32_t>(maskCount);
    // Full mask for the index arithmetic (T is 4 bytes here, so lanes match uint32_t).
    MicroAPI::MaskReg indexMask = MicroAPI::CreateMask<T>();
    MicroAPI::MaskReg dstMask;
    if constexpr (isNormalMode) {
        dstMask = MicroAPI::MoveMask<T>();
    }
    for (uint16_t i = 0; i < static_cast<uint16_t>(repeatTime); ++i) {
        if constexpr (!isNormalMode) {
            dstMask = MicroAPI::UpdateMask<T>(sregPlt);
        }
        DataCopy(offsetReg, srcOffsetLocal + i * u32OffsetRepeatCount);
        // convert addr offset into B32 element index: divide by 4 (implemented by ShiftRight 2 bit)
        ShiftRights(indexReg, offsetReg, (int16_t)2, indexMask);
        DataCopyGather(dstReg, (__ubuf__ uint32_t *)src + srcBaseIndex, indexReg, dstMask);
        DataCopy((__ubuf__ uint32_t *)dst + i * dstRepStride * blkCount, dstReg, dstMask);
    }
}

// Entry wrapper of the Level 0 B32 Gather: derives the per-repeat element counts
// from the hardware register/block geometry and delegates to the vector-function body.
template <typename T, bool isNormalMode>
__aicore__ inline void GatherApi0B32Impl(__ubuf__ T *dst, __ubuf__ T *src, __ubuf__ uint32_t *srcOffsetLocal,
    const uint32_t srcBaseIndex, const uint8_t repeatTime, const uint16_t &dstRepStride, const uint64_t maskCount)
{
    const uint32_t elemPerRepeat = static_cast<uint32_t>(VECTOR_REG_WIDTH / sizeof(T));
    const uint32_t offsetElemPerRepeat = static_cast<uint32_t>(VECTOR_REG_WIDTH / sizeof(uint32_t));
    const uint32_t elemPerBlk = static_cast<uint32_t>(ONE_BLK_SIZE / sizeof(T));
    VfGatherApi0B32<T, isNormalMode>(dst, src, srcOffsetLocal, srcBaseIndex, repeatTime, dstRepStride, elemPerRepeat,
        offsetElemPerRepeat, elemPerBlk, maskCount);
}

// Vector-function body of the Level 0 Gather (continuous mask) for B64 element types.
// Every B64 element is processed as a pair of B32 elements: the loaded byte offsets
// are converted to B64 indices, doubled into even/odd B32 indices, interleaved, and
// then used by a B32 gather.
// maskCount is expressed in B32 elements (the caller doubles the B64 count).
template <typename T, bool isNormalMode>
__simd_vf__ inline void VfGatherApi0B64Normal(__ubuf__ T *dst, __ubuf__ T *src, __ubuf__ uint32_t *srcOffsetLocal,
    const uint32_t srcBaseIndex, const uint8_t repeatTime, const uint16_t dstRepStride, uint32_t u64OffsetRepeatCount,
    uint32_t u32BlkCount, const uint64_t maskCount)
{
    MicroAPI::RegTensor<uint32_t> offsetReg;
    MicroAPI::RegTensor<uint32_t> indexReg;
    MicroAPI::RegTensor<uint32_t> oddIndexReg;
    MicroAPI::RegTensor<uint32_t> tmpReg;
    MicroAPI::RegTensor<uint32_t> indexU32Reg;
    MicroAPI::RegTensor<uint32_t> dstReg;
    // Remaining B32-element counter consumed by UpdateMask.
    uint32_t sregPlt = static_cast<uint32_t>(maskCount);
    // Half-pattern mask: only the lower half of the vreg holds valid B64 offsets.
    MicroAPI::MaskReg indexMask = MicroAPI::CreateMask<uint32_t, MicroAPI::MaskPattern::H>();
    MicroAPI::MaskReg dstMask;
    if constexpr (isNormalMode) {
        // Normal mode: build the B32-lane mask once from the doubled element count.
        dstMask = MicroAPI::UpdateMask<uint32_t>(sregPlt);
    }
    for (uint16_t i = 0; i < static_cast<uint16_t>(repeatTime); ++i) {
        if constexpr (!isNormalMode) {
            // Counter mode: refresh the mask from the remaining element count each repeat.
            dstMask = MicroAPI::UpdateMask<uint32_t>(sregPlt);
        }
        DataCopy(offsetReg, srcOffsetLocal + i * u64OffsetRepeatCount);
        // convert addr offset into B64 element index: divide by 8 (implemented by ShiftRight 3 bit)
        ShiftRights(indexReg, offsetReg, (int16_t)3, indexMask);
        // Consider every B64 element as two B32 elements
        // the index of lower 32-bit is 2 * i, the index of higher 32-bit is 2 * i + 1
        Muls(indexReg, indexReg, 2, indexMask);
        Adds(oddIndexReg, indexReg, 1, indexMask);
        // Interleave the separately calculated indices of the lower and higher 32-bit of every B64 element
        Interleave(indexU32Reg, tmpReg, indexReg, oddIndexReg);
        DataCopyGather(dstReg, (__ubuf__ uint32_t *)src + srcBaseIndex, indexU32Reg, dstMask);
        DataCopy((__ubuf__ uint32_t *)dst + i * dstRepStride * u32BlkCount, dstReg, dstMask);
    }
}

// Entry wrapper of the Level 0 B64 Gather (continuous mask). Each B64 element is
// handled as two B32 elements, so the mask count handed to the vf body is doubled.
template <typename T, bool isNormalMode>
__aicore__ inline void GatherApi0B64NormalImpl(__ubuf__ T *dst, __ubuf__ T *src, __ubuf__ uint32_t *srcOffsetLocal,
    const uint32_t srcBaseIndex, const uint8_t repeatTime, const uint16_t &dstRepStride, const uint64_t maskCount)
{
    const uint32_t offsetElemPerRepeat = static_cast<uint32_t>(VECTOR_REG_WIDTH / sizeof(T));
    const uint32_t b32ElemPerBlk = static_cast<uint32_t>(ONE_BLK_SIZE / sizeof(uint32_t));
    const uint64_t b32MaskCount = maskCount * 2;
    VfGatherApi0B64Normal<T, isNormalMode>(dst, src, srcOffsetLocal, srcBaseIndex, repeatTime, dstRepStride,
        offsetElemPerRepeat, b32ElemPerBlk, b32MaskCount);
}

// Vector-function body of the Level 0 Gather (bit-wise mask) for B64 element types.
// Same index scheme as the Normal variant (each B64 element becomes an even/odd pair
// of B32 indices); the difference is that in normal mode the destination mask is
// derived from the pre-set bit-wise vector mask instead of an element count.
template <typename T, bool isNormalMode>
__simd_vf__ inline void VfGatherApi0B64Bits(__ubuf__ T *dst, __ubuf__ T *src, __ubuf__ uint32_t *srcOffsetLocal,
    const uint32_t srcBaseIndex, const uint8_t repeatTime, const uint16_t dstRepStride, uint32_t u64OffsetRepeatCount,
    uint32_t u32BlkCount, const uint64_t maskCount)
{
    MicroAPI::RegTensor<uint32_t> offsetReg;
    MicroAPI::RegTensor<uint32_t> bitsIndexReg;
    MicroAPI::RegTensor<uint32_t> bitsOddIndexReg;
    MicroAPI::RegTensor<uint32_t> tmpReg;
    MicroAPI::RegTensor<uint32_t> indexU32Reg;
    MicroAPI::RegTensor<uint32_t> dstReg;
    // Remaining B32-element counter; only consumed in counter mode.
    uint32_t sregPlt = static_cast<uint32_t>(maskCount);
    MicroAPI::MaskReg dstMask;
    // Half-pattern mask: only the lower half of the vreg holds valid B64 offsets.
    MicroAPI::MaskReg indexMask = MicroAPI::CreateMask<uint32_t, MicroAPI::MaskPattern::H>();
    if constexpr (isNormalMode) {
        // Expand the per-B64-element bit mask into a per-B32-lane mask: pack the mask
        // bits, interleave the packed mask with itself so every bit appears twice
        // (one per B32 half of a B64 element), then unpack twice back to B32
        // granularity. NOTE(review): relies on MaskPack/MaskInterleave/MaskUnPack
        // semantics — confirm against the MicroAPI reference.
        MicroAPI::MaskReg tmpMask0;
        MicroAPI::MaskReg tmpMask1 = MicroAPI::MoveMask<uint16_t>();
        MicroAPI::MaskPack(tmpMask1, tmpMask1);
        MicroAPI::MaskInterleave<uint8_t>(dstMask, tmpMask0, tmpMask1, tmpMask1);
        MicroAPI::MaskUnPack(dstMask, dstMask);
        MicroAPI::MaskUnPack(dstMask, dstMask);
    }

    for (uint16_t i = 0; i < static_cast<uint16_t>(repeatTime); ++i) {
        if constexpr (!isNormalMode) {
            // Counter mode: refresh the mask from the remaining element count each repeat.
            dstMask = MicroAPI::UpdateMask<uint32_t>(sregPlt);
        }
        DataCopy(offsetReg, srcOffsetLocal + i * u64OffsetRepeatCount);
        // convert addr offset into B64 element index: divide by 8 (implemented by ShiftRight 3 bit)
        ShiftRights(bitsIndexReg, offsetReg, (int16_t)3, indexMask);
        // Consider every B64 element as two B32 elements
        // the index of lower 32-bit is 2 * i, the index of higher 32-bit is 2 * i + 1
        Muls(bitsIndexReg, bitsIndexReg, 2, indexMask);
        Adds(bitsOddIndexReg, bitsIndexReg, 1, indexMask);
        // Interleave the separately calculated indices of the lower and higher 32-bit of every B64 element
        Interleave(indexU32Reg, tmpReg, bitsIndexReg, bitsOddIndexReg);
        DataCopyGather(dstReg, (__ubuf__ uint32_t *)src + srcBaseIndex, indexU32Reg, dstMask);
        DataCopy((__ubuf__ uint32_t *)dst + i * dstRepStride * u32BlkCount, dstReg, dstMask);
    }
}

// Entry wrapper of the Level 0 B64 Gather (bit-wise mask). Each B64 element is
// handled as two B32 elements, so the mask count handed to the vf body is doubled.
template <typename T, bool isNormalMode>
__aicore__ inline void GatherApi0B64BitsImpl(__ubuf__ T *dst, __ubuf__ T *src, __ubuf__ uint32_t *srcOffsetLocal,
    const uint32_t srcBaseIndex, const uint8_t repeatTime, const uint16_t &dstRepStride, const uint64_t maskCount)
{
    const uint64_t b32MaskCount = maskCount * 2;
    const uint32_t offsetElemPerRepeat = static_cast<uint32_t>(VECTOR_REG_WIDTH / sizeof(T));
    const uint32_t b32ElemPerBlk = static_cast<uint32_t>(ONE_BLK_SIZE / sizeof(uint32_t));
    VfGatherApi0B64Bits<T, isNormalMode>(dst, src, srcOffsetLocal, srcBaseIndex, repeatTime, dstRepStride,
        offsetElemPerRepeat, b32ElemPerBlk, b32MaskCount);
}

// Vector-function body of the Level 2 (count mode) Gather for B8 element types.
// B8 byte offsets are already element indices, so no shift is needed. The gather is
// performed through B16 lanes, then the results are packed back down to continuous
// B8 data before the store.
template <typename T>
__simd_vf__ inline void VfGatherApi2B8(__ubuf__ T *dst, __ubuf__ T *src, __ubuf__ uint32_t *srcOffsetLocal,
    const uint32_t srcBaseIndex, const uint32_t count, uint32_t u32OffsetRepeatCount, uint32_t u8GatherRepeatCount)
{
    MicroAPI::RegTensor<uint32_t> offsetReg0;
    MicroAPI::RegTensor<uint32_t> offsetReg1;
    MicroAPI::RegTensor<uint16_t> indexReg;
    MicroAPI::RegTensor<uint16_t> dstReg;
    // Remaining-element counter consumed by UpdateMask every repeat.
    uint32_t sreg = static_cast<uint32_t>(count);
    MicroAPI::MaskReg dstMask;
    uint16_t repeatTime = CeilDivision(count, u8GatherRepeatCount);
    // Half-pattern B16 mask used by Select to join the two packed index halves.
    MicroAPI::MaskReg selectMask = MicroAPI::CreateMask<uint16_t, MicroAPI::MaskPattern::H>();
    for (uint16_t i = 0; i < static_cast<uint16_t>(repeatTime); ++i) {
        dstMask = MicroAPI::UpdateMask<uint16_t>(sreg);
        DataCopy(offsetReg0, srcOffsetLocal + (2 * i) * u32OffsetRepeatCount);
        DataCopy(offsetReg1, srcOffsetLocal + (2 * i + 1) * u32OffsetRepeatCount);
        // extract the lower 16-bit of each uint32_t offset into uint16_t index data:
        // for offsetReg0, pack every lower 16-bit into the lower half of the vreg:
        // 0x00FF00FE00FD... -> 0xFFFEFD...000000...
        // for offsetReg1, pack every higher 16-bit into the higher half of the vreg:
        // 0x001100120013... -> 0x000000...111213...
        MicroAPI::Pack<uint16_t, uint32_t, MicroAPI::HighLowPart::LOWEST>((MicroAPI::RegTensor<uint16_t> &)offsetReg0,
            offsetReg0);
        MicroAPI::Pack<uint16_t, uint32_t, MicroAPI::HighLowPart::HIGHEST>((MicroAPI::RegTensor<uint16_t> &)offsetReg1,
            offsetReg1);
        // Select the effective data in offsetReg0 and offsetReg1 and joint them into a complete uint16_t type indexReg:
        // 0xFFFEFD...111213...
        Select(indexReg, (MicroAPI::RegTensor<uint16_t> &)offsetReg0, (MicroAPI::RegTensor<uint16_t> &)offsetReg1,
            selectMask);
        DataCopyGather(dstReg, (__ubuf__ uint8_t *)src + srcBaseIndex, indexReg, dstMask);
        // remove the higher zeros of the uint16_t data gathered by the Micro Gather instr, and pack into continuous B8
        // data: 0x010203... -> 0x123...000... (only the lower 128 elements are effective)
        MicroAPI::Pack((MicroAPI::RegTensor<uint8_t> &)dstReg, dstReg);
        // convert uint16_t type preg to B8 type preg: 0b010101... -> 0b111...000... (lower 128-bit effective)
        MicroAPI::MaskPack(dstMask, dstMask);
        DataCopy((__ubuf__ uint8_t *)dst + i * u8GatherRepeatCount, (MicroAPI::RegTensor<uint8_t> &)dstReg, dstMask);
    }
}

// Entry wrapper of the Level 2 B8 Gather. B8 data is gathered through B16 lanes,
// so one repeat covers a B16 register's worth of elements.
template <typename T>
__aicore__ inline void GatherApi2B8Impl(__ubuf__ T *dst, __ubuf__ T *src, __ubuf__ uint32_t *srcOffsetLocal,
    const uint32_t srcBaseIndex, const uint32_t count)
{
    const uint32_t gatherElemPerRepeat = static_cast<uint32_t>(VECTOR_REG_WIDTH / sizeof(uint16_t));
    const uint32_t offsetElemPerRepeat = static_cast<uint32_t>(VECTOR_REG_WIDTH / sizeof(uint32_t));
    VfGatherApi2B8<T>(dst, src, srcOffsetLocal, srcBaseIndex, count, offsetElemPerRepeat, gatherElemPerRepeat);
}

// Vector-function body of the Level 2 (count mode) Gather for B16 element types.
// Per repeat: load two vregs of uint32_t byte offsets, convert them to B16 element
// indices, merge them into one uint16_t index vreg, gather from src + srcBaseIndex,
// and store one vreg of results to dst. The mask is refreshed from the remaining
// element count every repeat.
template <typename T>
__simd_vf__ inline void VfGatherApi2B16(__ubuf__ T *dst, __ubuf__ T *src, __ubuf__ uint32_t *srcOffsetLocal,
    const uint32_t srcBaseIndex, const uint32_t count, uint32_t dstRepeatCount, uint32_t u32OffsetRepeatCount,
    uint16_t repeatTime)
{
    MicroAPI::RegTensor<uint32_t> api2OffsetReg0;
    MicroAPI::RegTensor<uint32_t> api2OffsetReg1;
    MicroAPI::RegTensor<uint16_t> indexReg;
    MicroAPI::RegTensor<uint16_t> dstReg;
    // Full B16 mask for the index arithmetic below.
    MicroAPI::MaskReg indexMask = MicroAPI::CreateMask<uint16_t>();
    MicroAPI::MaskReg dstMask;
    // Half-pattern B16 mask used by Select to join the two packed index halves.
    MicroAPI::MaskReg selectMask = MicroAPI::CreateMask<uint16_t, MicroAPI::MaskPattern::H>();
    // Remaining-element counter consumed by UpdateMask every repeat.
    uint32_t sreg = static_cast<uint32_t>(count);
    for (uint16_t i = 0; i < static_cast<uint16_t>(repeatTime); ++i) {
        dstMask = MicroAPI::UpdateMask<uint16_t>(sreg);
        DataCopy(api2OffsetReg0, srcOffsetLocal + (2 * i) * u32OffsetRepeatCount);
        DataCopy(api2OffsetReg1, srcOffsetLocal + (2 * i + 1) * u32OffsetRepeatCount);
        // convert addr offset into B16 element index: divide by 2 (implemented by ShiftRight 1 bit)
        ShiftRights(api2OffsetReg0, api2OffsetReg0, (int16_t)1, indexMask);
        ShiftRights(api2OffsetReg1, api2OffsetReg1, (int16_t)1, indexMask);
        // extract the lower 16-bit of each uint32_t offset into uint16_t index data:
        // for api2OffsetReg0, pack every lower 16-bit into the lower half of the vreg:
        // 0x00FF00FE00FD... -> 0xFFFEFD...000000...
        // for api2OffsetReg1, pack every higher 16-bit into the higher half of the vreg:
        // 0x001100120013... -> 0x000000...111213...
        MicroAPI::Pack<uint16_t, uint32_t, MicroAPI::HighLowPart::LOWEST>((MicroAPI::RegTensor<uint16_t> &)api2OffsetReg0,
            api2OffsetReg0);
        MicroAPI::Pack<uint16_t, uint32_t, MicroAPI::HighLowPart::HIGHEST>((MicroAPI::RegTensor<uint16_t> &)api2OffsetReg1,
            api2OffsetReg1);
        // Select the effective data in api2OffsetReg0 and api2OffsetReg1 and joint them into a complete uint16_t type
        // indexReg: 0xFFFEFD...111213...
        Select(indexReg, (MicroAPI::RegTensor<uint16_t> &)api2OffsetReg0, (MicroAPI::RegTensor<uint16_t> &)api2OffsetReg1,
            selectMask);
        DataCopyGather(dstReg, (__ubuf__ uint16_t *)src + srcBaseIndex, indexReg, dstMask);
        DataCopy((__ubuf__ uint16_t *)dst + i * dstRepeatCount, dstReg, dstMask);
    }
}

// Entry wrapper of the Level 2 B16 Gather: derives per-repeat element counts and the
// loop count from the total element count, then delegates to the vector-function body.
template <typename T>
__aicore__ inline void GatherApi2B16Impl(__ubuf__ T *dst, __ubuf__ T *src, __ubuf__ uint32_t *srcOffsetLocal,
    const uint32_t srcBaseIndex, const uint32_t count)
{
    const uint32_t elemPerRepeat = static_cast<uint32_t>(VECTOR_REG_WIDTH / sizeof(T));
    const uint32_t offsetElemPerRepeat = static_cast<uint32_t>(VECTOR_REG_WIDTH / sizeof(uint32_t));
    const uint16_t loops = CeilDivision(count, elemPerRepeat);
    VfGatherApi2B16<T>(dst, src, srcOffsetLocal, srcBaseIndex, count, elemPerRepeat, offsetElemPerRepeat, loops);
}

// Vector-function body of the Level 2 (count mode) Gather for B32 element types.
// Per repeat: load one vreg of uint32_t byte offsets, convert them to B32 element
// indices, gather from src + srcBaseIndex, and store one vreg of results to dst.
// The mask is refreshed from the remaining element count every repeat.
template <typename T>
__simd_vf__ inline void VfGatherApi2B32(__ubuf__ T *dst, __ubuf__ T *src, __ubuf__ uint32_t *srcOffsetLocal,
    const uint32_t srcBaseIndex, const uint32_t count, uint32_t dstRepeatCount, uint32_t u32OffsetRepeatCount,
    uint16_t repeatTime)
{
    MicroAPI::RegTensor<uint32_t> offsetReg;
    MicroAPI::RegTensor<uint32_t> indexReg;
    MicroAPI::RegTensor<uint32_t> dstReg;
    // Full mask for the index arithmetic (T is 4 bytes here, so lanes match uint32_t).
    MicroAPI::MaskReg indexMask = MicroAPI::CreateMask<T>();
    MicroAPI::MaskReg dstMask;
    // Remaining-element counter consumed by UpdateMask every repeat.
    uint32_t sreg = static_cast<uint32_t>(count);
    for (uint16_t i = 0; i < static_cast<uint16_t>(repeatTime); ++i) {
        dstMask = MicroAPI::UpdateMask<uint32_t>(sreg);
        DataCopy(offsetReg, srcOffsetLocal + i * u32OffsetRepeatCount);
        // convert addr offset into B32 element index: divide by 4 (implemented by ShiftRight 2 bit)
        ShiftRights(indexReg, offsetReg, (int16_t)2, indexMask);
        DataCopyGather(dstReg, (__ubuf__ uint32_t *)src + srcBaseIndex, indexReg, dstMask);
        DataCopy((__ubuf__ uint32_t *)dst + i * dstRepeatCount, dstReg, dstMask);
    }
}

// Entry wrapper of the Level 2 B32 Gather: derives per-repeat element counts and the
// loop count from the total element count, then delegates to the vector-function body.
template <typename T>
__aicore__ inline void GatherApi2B32Impl(__ubuf__ T *dst, __ubuf__ T *src, __ubuf__ uint32_t *srcOffsetLocal,
    const uint32_t srcBaseIndex, const uint32_t count)
{
    const uint32_t elemPerRepeat = static_cast<uint32_t>(VECTOR_REG_WIDTH / sizeof(T));
    const uint32_t offsetElemPerRepeat = static_cast<uint32_t>(VECTOR_REG_WIDTH / sizeof(uint32_t));
    const uint16_t loops = CeilDivision(count, elemPerRepeat);
    VfGatherApi2B32<T>(dst, src, srcOffsetLocal, srcBaseIndex, count, elemPerRepeat, offsetElemPerRepeat, loops);
}

// Vector-function body of the Level 2 (count mode) Gather for B64 element types.
// Every B64 element is processed as a pair of B32 elements: the loaded byte offsets
// are converted to B64 indices, doubled into even/odd B32 indices, interleaved, and
// then used by a B32 gather. u32Count is the doubled element count in B32 units.
// NOTE(review): the count parameter is unused here (u32Count drives the masking);
// presumably kept for signature parity with the other Vf variants — confirm.
template <typename T>
__simd_vf__ inline void VfGatherApi2B64(__ubuf__ T *dst, __ubuf__ T *src, __ubuf__ uint32_t *srcOffsetLocal,
    const uint32_t srcBaseIndex, const uint32_t count, uint32_t u32Count, uint32_t u64OffsetRepeatCount,
    uint32_t u32RepeatCount, uint16_t repeatTime)
{
    MicroAPI::RegTensor<uint32_t> offsetReg;
    MicroAPI::RegTensor<uint32_t> indexReg;
    MicroAPI::RegTensor<uint32_t> oddIndexReg;
    MicroAPI::RegTensor<uint32_t> tmpReg;
    MicroAPI::RegTensor<uint32_t> indexU32Reg;
    MicroAPI::RegTensor<uint32_t> dstReg;
    // Half-pattern mask: only the lower half of the vreg holds valid B64 offsets.
    MicroAPI::MaskReg indexMask = MicroAPI::CreateMask<uint32_t, MicroAPI::MaskPattern::H>();
    MicroAPI::MaskReg dstMask;
    // Remaining B32-element counter consumed by UpdateMask every repeat.
    uint32_t sreg = static_cast<uint32_t>(u32Count);
    for (uint16_t i = 0; i < static_cast<uint16_t>(repeatTime); ++i) {
        dstMask = MicroAPI::UpdateMask<uint32_t>(sreg);
        DataCopy(offsetReg, srcOffsetLocal + i * u64OffsetRepeatCount);
        // convert addr offset into B64 element index: divide by 8 (implemented by ShiftRight 3 bit)
        ShiftRights(indexReg, offsetReg, (int16_t)3, indexMask);
        // Consider every B64 element as two B32 elements
        // the index of lower 32-bit is 2*i, the index of higher 32-bit is 2*i+1
        Muls(indexReg, indexReg, 2, indexMask);
        Adds(oddIndexReg, indexReg, 1, indexMask);
        // Interleave the separately calculated indices of the lower and higher 32-bit of every B64 element
        Interleave(indexU32Reg, tmpReg, indexReg, oddIndexReg);
        DataCopyGather(dstReg, (__ubuf__ uint32_t *)src + srcBaseIndex, indexU32Reg, dstMask);
        DataCopy((__ubuf__ uint32_t *)dst + i * u32RepeatCount, dstReg, dstMask);
    }
}

// Entry wrapper of the Level 2 B64 Gather. Every B64 element is processed as two
// B32 elements, so the element count is doubled before computing the loop count.
template <typename T>
__aicore__ inline void GatherApi2B64Impl(__ubuf__ T *dst, __ubuf__ T *src, __ubuf__ uint32_t *srcOffsetLocal,
    const uint32_t srcBaseIndex, const uint32_t count)
{
    const uint32_t b32Count = static_cast<uint32_t>(count) * 2;
    const uint32_t offsetElemPerRepeat = static_cast<uint32_t>(VECTOR_REG_WIDTH / sizeof(T));
    const uint32_t b32ElemPerRepeat = static_cast<uint32_t>(VECTOR_REG_WIDTH / sizeof(uint32_t));
    const uint16_t loops = CeilDivision(b32Count, b32ElemPerRepeat);
    VfGatherApi2B64<T>(dst, src, srcOffsetLocal, srcBaseIndex, count, b32Count, offsetElemPerRepeat,
        b32ElemPerRepeat, loops);
}

// Gather::Level 0 Normal mode
// Dispatches by element width to the matching B16/B32/B64 implementation. In normal
// mode the continuous mask is published to the vector mask register; in counter mode
// the repeat count is re-derived from the element count carried in mask.
template <typename T>
__aicore__ inline void GatherImpl(__ubuf__ T *dst, __ubuf__ T *src, __ubuf__ uint32_t *srcOffsetLocal,
    const uint32_t srcLength, const uint32_t srcBaseAddr, const uint64_t mask, const uint8_t repeatTime,
    const uint16_t &dstRepStride)
{
    static_assert(SupportBytes<T, 2, 4, 8>(), "Gather only support type b16/b32/b64 on current device");

    const bool normalMode = !Internal::IsCounterMode();
    uint8_t repeats = repeatTime;
    if (normalMode) {
        if constexpr (sizeof(T) == 2) {
            SetVectorMask<uint16_t>(mask);
        } else {
            SetVectorMask<uint32_t>(mask);
        }
    } else {
        repeats = static_cast<uint8_t>(Internal::VecMicroGetRepeatTimes<T, false>(mask, repeatTime));
    }

    // B64 data is gathered as pairs of B32 elements, so its base index is in B32 units.
    constexpr uint32_t baseUnit = (sizeof(T) == 8) ? sizeof(uint32_t) : sizeof(T);
    const uint32_t srcBaseIndex = srcBaseAddr / baseUnit;
    if constexpr (sizeof(T) == 2) {
        if (normalMode) {
            GatherApi0B16Impl<T, true>(dst, src, srcOffsetLocal, srcBaseIndex, repeats, dstRepStride, mask);
        } else {
            GatherApi0B16Impl<T, false>(dst, src, srcOffsetLocal, srcBaseIndex, repeats, dstRepStride, mask);
        }
    } else if constexpr (sizeof(T) == 4) {
        if (normalMode) {
            GatherApi0B32Impl<T, true>(dst, src, srcOffsetLocal, srcBaseIndex, repeats, dstRepStride, mask);
        } else {
            GatherApi0B32Impl<T, false>(dst, src, srcOffsetLocal, srcBaseIndex, repeats, dstRepStride, mask);
        }
    } else {
        if (normalMode) {
            GatherApi0B64NormalImpl<T, true>(dst, src, srcOffsetLocal, srcBaseIndex, repeats, dstRepStride, mask);
        } else {
            GatherApi0B64NormalImpl<T, false>(dst, src, srcOffsetLocal, srcBaseIndex, repeats, dstRepStride, mask);
        }
    }
}

/* **************************************************************************************************
 * Gather                                             *
 * ************************************************************************************************* */
// Gather::Level 0 Bit-wise mode
// Same dispatch scheme as the normal-mode overload, but the mask is a two-element
// bit-wise mask array; only mask[0] is forwarded to the per-width implementations.
template <typename T>
__aicore__ inline void GatherImpl(__ubuf__ T *dst, __ubuf__ T *src, __ubuf__ uint32_t *srcOffsetLocal,
    const uint32_t srcLength, const uint32_t srcBaseAddr, const uint64_t mask[], const uint8_t repeatTime,
    const uint16_t &dstRepStride)
{
    static_assert(SupportBytes<T, 2, 4, 8>(), "Gather only support type b16/b32/b64 on current device");

    const bool normalMode = !Internal::IsCounterMode();
    uint8_t repeats = repeatTime;
    if (normalMode) {
        // Publish both 64-bit halves of the caller's bit mask to the vector mask register.
        if constexpr (sizeof(T) == 2) {
            SetVectorMask<uint16_t>(mask[1], mask[0]);
        } else {
            SetVectorMask<uint32_t>(mask[1], mask[0]);
        }
    } else {
        // Counter mode: mask[0] carries the element count.
        repeats = static_cast<uint8_t>(Internal::VecMicroGetRepeatTimes<T, false>(mask[0], repeatTime));
    }

    // B64 data is gathered as pairs of B32 elements, so its base index is in B32 units.
    constexpr uint32_t baseUnit = (sizeof(T) == 8) ? sizeof(uint32_t) : sizeof(T);
    const uint32_t srcBaseIndex = srcBaseAddr / baseUnit;
    if constexpr (sizeof(T) == 2) {
        if (normalMode) {
            GatherApi0B16Impl<T, true>(dst, src, srcOffsetLocal, srcBaseIndex, repeats, dstRepStride, mask[0]);
        } else {
            GatherApi0B16Impl<T, false>(dst, src, srcOffsetLocal, srcBaseIndex, repeats, dstRepStride, mask[0]);
        }
    } else if constexpr (sizeof(T) == 4) {
        if (normalMode) {
            GatherApi0B32Impl<T, true>(dst, src, srcOffsetLocal, srcBaseIndex, repeats, dstRepStride, mask[0]);
        } else {
            GatherApi0B32Impl<T, false>(dst, src, srcOffsetLocal, srcBaseIndex, repeats, dstRepStride, mask[0]);
        }
    } else {
        if (normalMode) {
            GatherApi0B64BitsImpl<T, true>(dst, src, srcOffsetLocal, srcBaseIndex, repeats, dstRepStride, mask[0]);
        } else {
            GatherApi0B64BitsImpl<T, false>(dst, src, srcOffsetLocal, srcBaseIndex, repeats, dstRepStride, mask[0]);
        }
    }
}

/* **************************************************************************************************
 * Gather                                             *
 * ************************************************************************************************* */
// Gather::Level 2 Count mode
// Dispatches by element width to the matching B8/B16/B32/B64 count-mode implementation.
// NOTE(review): the isSetMask template parameter is not consumed in this overload —
// presumably kept for interface parity with other devices; confirm.
template <typename T, bool isSetMask = true>
__aicore__ inline void GatherImpl(__ubuf__ T *dst, __ubuf__ T *src, __ubuf__ uint32_t *srcOffsetLocal,
    const uint32_t srcBaseAddr, const uint32_t count)
{
    static_assert(SupportBytes<T, 1, 2, 4, 8>(), "Gather only support type b8/b16/b32/b64 on current device");

    // B64 data is gathered as pairs of B32 elements, so its base index is in B32 units.
    constexpr uint32_t baseUnit = (sizeof(T) == 8) ? sizeof(uint32_t) : sizeof(T);
    const uint32_t srcBaseIndex = srcBaseAddr / baseUnit;
    if constexpr (sizeof(T) == 1) {
        GatherApi2B8Impl(dst, src, srcOffsetLocal, srcBaseIndex, count);
    } else if constexpr (sizeof(T) == 2) {
        GatherApi2B16Impl(dst, src, srcOffsetLocal, srcBaseIndex, count);
    } else if constexpr (sizeof(T) == 4) {
        GatherApi2B32Impl(dst, src, srcOffsetLocal, srcBaseIndex, count);
    } else {
        GatherApi2B64Impl(dst, src, srcOffsetLocal, srcBaseIndex, count);
    }
}

/* **************************************************************************************************
 * Gatherb                                             *
 * ************************************************************************************************* */
// Vector-function body of Gatherb: per repeat, load one vreg of block indices from
// srcOffsetLocal and gather whole data blocks from src, then store them to dst using
// the configured block/repeat strides. For B64 the store goes through a B32 view with
// post-update addressing; for all other widths dst is advanced by explicit indexing.
template <typename T>
__simd_vf__ inline void GatherbImpl(__ubuf__ T *dst, __ubuf__ T *src, __ubuf__ uint32_t *srcOffsetLocal,
    const uint8_t repeatTime, const GatherRepeatParams repeatParams)
{
    // Elements per 32-byte block for the destination type.
    constexpr uint32_t oneBlkNum = static_cast<uint32_t>(ONE_BLK_SIZE / sizeof(T));
    MicroAPI::RegTensor<T> dstReg;
    MicroAPI::RegTensor<uint32_t> indexReg;
    MicroAPI::MaskReg fullPreg = MicroAPI::CreateMask<uint8_t>();

    for (uint16_t i = 0; i < static_cast<uint16_t>(repeatTime); ++i) {
        MicroAPI::DataCopy(indexReg, srcOffsetLocal + i * DEFAULT_BLK_NUM);
        if constexpr (sizeof(T) == 8) {
            // pg is the predicate for block index, and regarded as B32 format.
            // For convenience, we use fullPreg format to represent it.
            MicroAPI::DataCopyGatherB(dstReg, src, indexReg, fullPreg);
            // POST_MODE_UPDATE advances dst by the block/repeat strides on every repeat,
            // so no explicit index arithmetic is needed on this path.
            MicroAPI::DataCopy<uint32_t, MicroAPI::DataCopyMode::DATA_BLOCK_COPY,
                MicroAPI::PostLiteral::POST_MODE_UPDATE>((__ubuf__ uint32_t *&)dst,
                (MicroAPI::RegTensor<uint32_t> &)dstReg, static_cast<uint32_t>(repeatParams.dstBlkStride),
                static_cast<uint32_t>(repeatParams.dstRepStride), fullPreg);
        } else {
            MicroAPI::DataCopyGatherB(dstReg, src, indexReg, fullPreg);
            MicroAPI::DataCopy<T, MicroAPI::DataCopyMode::DATA_BLOCK_COPY>(
                dst + i * repeatParams.dstRepStride * oneBlkNum, dstReg, repeatParams.dstBlkStride, fullPreg);
        }
    }
}

// Public entry of Gatherb: validates the element type at compile time and forwards to
// the vector-function body. srcLength is accepted for interface compatibility only.
template <typename T>
__aicore__ inline void GatherbImpl(__ubuf__ T *dst, __ubuf__ T *src, __ubuf__ uint32_t *srcOffsetLocal,
    const uint32_t srcLength, const uint8_t repeatTime, const GatherRepeatParams &repeatParams)
{
    static_assert(SupportBytes<T, 1, 2, 4, 8>(), "Failed to check dtype in Gatherb, current api support dtype"
        "combination is src and dst both: uint8_t, int8_t, uint16_t, int16_t, uint32_t, int32_t, half, float, "
        "bfloat16_t, uint64_t, int64_t.");
    (void)srcLength;  // not needed by the vf body
    GatherbImpl(dst, src, srcOffsetLocal, repeatTime, repeatParams);
}
}
#endif // ASCENDC_MODULE_OPERATOR_VEC_GATHER_IMPL_H